/* SPDX-License-Identifier: GPL-2.0 */
/*
 * Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 and
 * only version 2 as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 */

#include <linux/kthread.h>
#include <linux/module.h>
#include <linux/net.h>
#include <linux/inet.h>
#include <linux/in.h>
#include <linux/vm_sockets.h>
#include <linux/kprobes.h>
#include <net/vsock_addr.h>
#include <net/sock.h>
#include <net/tcp.h>
#include <uapi/asm-generic/errno.h>

/*
 * Raw mapping string supplied through the cid_ip_map module parameter.
 * It is parsed as comma-separated "<cid>-<ipv4>" entries and is modified
 * in place while being parsed.
 */
char cid_ip_map[100] = {0};
/* Parsed CIDs; cids[i] pairs with ips[i]. Only the first cid_index entries are valid. */
unsigned int cids[10] = {0};
/*
 * Parsed IPv4 address strings. A canonical dotted quad such as
 * 255.255.255.255 is 15 characters, so 16 bytes hold it plus the NUL.
 */
char ips[10][16] = {0};
/* Signature of kallsyms_lookup_name(), whose address is obtained at runtime through a kprobe. */
typedef unsigned long (*kallsyms_lookup_name_t)(const char *name);

/*
 * Info-level log helper. Every message is prefixed with
 * "[module::file:line] " and gets a trailing newline appended, so callers
 * pass the format string without "\n".
 */
#define vs_info(fmt, ...)						\
	({								\
		pr_info("[%s::%s:%4d] " fmt "\n", KBUILD_MODNAME,	\
			kbasename(__FILE__), __LINE__, ##__VA_ARGS__);	\
	})

/*
 * Error-level log helper. Every message is prefixed with
 * "[module::file:line] " and gets a trailing newline appended, so callers
 * pass the format string without "\n".
 */
#define vs_err(fmt, ...)						\
	({								\
		pr_err("[%s::%s:%4d] " fmt "\n", KBUILD_MODNAME,	\
			kbasename(__FILE__), __LINE__, ##__VA_ARGS__);	\
	})

/* AF_INET operation tables; copied into vs_stream_ops / vs_dgram_ops at init. */
extern const struct proto_ops inet_stream_ops;
extern const struct proto_ops inet_dgram_ops;
/* AF_INET entry points that the vs_bind / vs_*_connect wrappers forward to. */
extern int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len);
extern int inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, int addr_len, int flags);
extern int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr, int addr_len, int flags);
/*
 * Probe registered briefly in kallsyms_hack_init() only to read back the
 * resolved address of kallsyms_lookup_name() from kp.addr.
 */
static struct kprobe kp = {
	.symbol_name = "kallsyms_lookup_name"
};

/* Address of kallsyms_lookup_name(), filled in by kallsyms_hack_init(). */
kallsyms_lookup_name_t vs_kallsyms_lookup_name;
/* Address of the "net_families" symbol as returned by vs_kallsyms_lookup_name(). */
static struct net_proto_family __rcu **vs_net_families;
/* AF_INET family entry read from vs_net_families; its create() backs every simulated socket. */
struct net_proto_family *inet_pf;
/* AF_VSOCK family entry read at init time; re-registered when the module exits. */
struct net_proto_family *vsock_pf;
/* Copies of the inet stream/dgram ops with bind and connect replaced by the vs_* wrappers. */
struct proto_ops vs_stream_ops;
struct proto_ops vs_dgram_ops;
/* Number of valid entries in cids[] / ips[]. */
int cid_index = 0;

/*
 * Parse one "<cid>-<ipv4>" entry and append it to the cids[] / ips[] tables.
 *
 * @str: NUL-terminated entry. It is modified in place: the first '-' is
 *       overwritten with '\0' to split the CID from the address.
 *
 * Malformed entries, and entries arriving once the tables are full, are
 * logged and skipped; nothing is returned to the caller.
 */
void do_parse_cid_ip_map(char *str)
{
	char *sep;
	char *ip;
	size_t ip_len;
	unsigned int cid;
	u32 s_addr;

	sep = strstr(str, "-");
	if (sep == NULL) {
		vs_err("input cid_ip_map %s not valid", str);
		return;
	}
	if (cid_index >= (int)(sizeof(cids) / sizeof(cids[0]))) {
		vs_err("cid index exceed upper limit");
		return;
	}
	ip = sep + 1;
	*sep = '\0';
	if (kstrtouint(str, 10, &cid)) {
		vs_err("input cid:%s not a number", str);
		return;
	}
	/*
	 * in4_pton() with delim == -1 stops at the first unknown character and
	 * accepts leading zeros, so a string that passes it can still be longer
	 * than a slot in ips[]. Reject it here instead of overflowing the slot.
	 */
	ip_len = strlen(ip);
	if (ip_len >= sizeof(ips[0])) {
		vs_err("input ip:%s too long", ip);
		return;
	}
	if (in4_pton(ip, ip_len, (u8 *)&s_addr, -1, NULL) != 1) {
		vs_err("input ip:%s not valid", ip);
		return;
	}
	cids[cid_index] = cid;
	/* Length was bounded above; copy the string including its NUL. */
	memcpy(ips[cid_index], ip, ip_len + 1);
	cid_index++;
}

/*
 * Split the cid_ip_map module parameter on ',' and hand every piece to
 * do_parse_cid_ip_map(). The buffer is modified in place: each ',' is
 * replaced by '\0'. An empty parameter is only logged.
 */
void parse_cid_ip_maps(void)
{
	char *entry = cid_ip_map;
	char *comma;

	if (cid_ip_map[0] == '\0') {
		vs_err("cid_ip_map not assigned");
		return;
	}
	if (strlen(cid_ip_map) >= 99) {
		vs_err("input cid_ip_map too long, just cut it, pls check to make sure it works");
		cid_ip_map[99] = '\0';
	}

	/* Every entry followed by a comma ... */
	while ((comma = strstr(entry, ",")) != NULL) {
		*comma = '\0';
		do_parse_cid_ip_map(entry);
		entry = comma + 1;
	}
	/* ... and the final (or only) entry after the last comma. */
	do_parse_cid_ip_map(entry);
}

/*
 * Look up @cid in the parsed cids[] table.
 *
 * Returns the position of the first matching entry, or -1 when none of the
 * cid_index valid entries matches.
 */
int get_cid_index(unsigned int cid)
{
	int slot = 0;

	while (slot < cid_index) {
		if (cids[slot] == cid)
			return slot;
		slot++;
	}

	return -1;
}

static int vm_to_in(struct sockaddr_in *in_addr, struct sockaddr *addr)
{
	struct sockaddr_vm *vm_addr = NULL;
	char *ip = NULL;
	int cid = 0;
	int index = -1;
	int ret = 0;

	ret = vsock_addr_cast(addr, sizeof(struct sockaddr_vm), &vm_addr);
	if (!vm_addr) {
		vs_err("vsock_addr_cast failed, vm_addr:0x%llx, get ret:%d\n", (u64)vm_addr, ret);
		return ret;
	}
	cid = vm_addr->svm_cid;
	index = get_cid_index(cid);
	if (index < 0) {
		vs_err("index of cid:%d not found", cid);
		return -EINVAL;
	}
	ip = ips[index];
	vs_err("vsock use ip:%s", ip);
	in_addr->sin_family = AF_INET;
	in_addr->sin_port = htons((u16)vm_addr->svm_port);
	in4_pton(ip, strlen(ip), (u8 *)&in_addr->sin_addr.s_addr, -1, NULL);

	return 0;
}

static int vs_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
	int err;
	struct sockaddr_in in_addr;

	err = vm_to_in(&in_addr, addr);
	if (err) {
		vs_err("failed to convert vm_addr to in_addr");
		return err;
	}

	vs_err("call inet_bind, family:%d, port:%d, addr:%x", in_addr.sin_family, in_addr.sin_port, in_addr.sin_addr.s_addr);
	return inet_bind(sock, (struct sockaddr *)&in_addr, sizeof(in_addr));
}

int vs_stream_connect(struct socket *sock, struct sockaddr *uaddr, int addr_len, int flags)
{
	int err;
	struct sockaddr_in in_addr;

	err = vm_to_in(&in_addr, uaddr);
	if (err) {
		vs_err("failed to convert vm_addr to in_addr");
		return err;
	}
	vs_err("call inet_stream_connect, family:%d, port:%d, addr:%x", in_addr.sin_family, in_addr.sin_port, in_addr.sin_addr.s_addr);
	err = inet_stream_connect(sock, (struct sockaddr *)&in_addr, sizeof(in_addr), flags);
	vs_err("inet_stream_connect get ret:%d", err);
	return err;
}

int vs_dgram_connect(struct socket *sock, struct sockaddr *uaddr, int addr_len, int flags)
{
	int err;
	struct sockaddr_in in_addr;

	err = vm_to_in(&in_addr, uaddr);
	if (err) {
		vs_err("failed to convert vm_addr to in_addr");
		return err;
	}
	return inet_dgram_connect(sock, (struct sockaddr *)&in_addr, sizeof(in_addr), flags);
}

/*
 * Resolve the kernel internals this module depends on and prepare the
 * replacement proto_ops tables.
 *
 * Steps:
 *  1. Register and immediately unregister a kprobe on kallsyms_lookup_name
 *     to learn that function's address.
 *  2. Use it to locate the "net_families" table and remember the AF_INET
 *     and AF_VSOCK entries.
 *  3. Copy inet_stream_ops / inet_dgram_ops and override owner, bind and
 *     connect with this module's handlers.
 *
 * Returns 0 on success and -1 on any failure. vsock_pf may legitimately be
 * NULL afterwards if no AF_VSOCK family was registered.
 */
int kallsyms_hack_init(void)
{
	int ret;

	ret = register_kprobe(&kp);
	if (ret < 0) {
		vs_err("register kprobe failed, Please confirm whether kprobe is enabled, ret:%d", ret);
		return -1;
	}
	vs_kallsyms_lookup_name = (kallsyms_lookup_name_t) kp.addr;
	unregister_kprobe(&kp);
	if (vs_kallsyms_lookup_name == NULL) {
		vs_err("get kallsyms function by kprobe failed.");
		return -1;
	}

	vs_net_families = (struct net_proto_family __rcu **)vs_kallsyms_lookup_name("net_families");
	/* The lookup returns 0 for an unknown symbol; do not index a NULL table. */
	if (vs_net_families == NULL) {
		vs_err("symbol net_families not found");
		return -1;
	}

	rcu_read_lock();
	inet_pf = rcu_dereference(vs_net_families[AF_INET]);
	vsock_pf = rcu_dereference(vs_net_families[AF_VSOCK]);
	rcu_read_unlock();

	/* Every simulated socket is created through inet_pf->create(). */
	if (inet_pf == NULL || inet_pf->create == NULL) {
		vs_err("AF_INET family not registered");
		return -1;
	}

	memcpy(&vs_stream_ops, &inet_stream_ops, sizeof(struct proto_ops));
	vs_stream_ops.owner = THIS_MODULE;
	vs_stream_ops.bind = vs_bind;
	vs_stream_ops.connect = vs_stream_connect;

	memcpy(&vs_dgram_ops, &inet_dgram_ops, sizeof(struct proto_ops));
	vs_dgram_ops.owner = THIS_MODULE;
	vs_dgram_ops.bind = vs_bind;
	vs_dgram_ops.connect = vs_dgram_connect;

	return 0;
}

/*
 * Counterpart of kallsyms_hack_init(). Intentionally empty: the kprobe is
 * already unregistered inside kallsyms_hack_init() and nothing else is
 * allocated there.
 */
void kallsyms_hack_fini(void)
{
}

static int vsock_sim_create(struct net *net, struct socket *sock, int protocol, int kern)
{
	if (!sock)
		return -EINVAL;
	vs_err("in vsock_sim_create protocol:%d", protocol);
	// protocol should be tamperd here
	protocol = IPPROTO_IP;
	inet_pf->create(net, sock, protocol, kern);

	// update sock ops
	switch (sock->type) {
	case SOCK_DGRAM:
		sock->ops = &vs_dgram_ops;
		break;
	case SOCK_STREAM:
		sock->ops = &vs_stream_ops;
		break;
	default:
		return -ESOCKTNOSUPPORT;
	}

	return 0;
}

/*
 * Family descriptor registered for AF_VSOCK by vsock_sim_init(); socket
 * creation for that family is routed to vsock_sim_create().
 */
static const struct net_proto_family vsock_sim_family_ops = {
	.owner	= THIS_MODULE,
	.family	= AF_VSOCK,
	.create	= vsock_sim_create,
};

/*
 * Module entry point: resolve kernel internals, parse the CID-to-IP map and
 * replace the registered AF_VSOCK family with the simulated one.
 *
 * Returns 0 on success, -ENOENT when kallsyms_hack_init() fails, or the
 * sock_register() error. On a registration failure the previously
 * registered AF_VSOCK family (if any) is put back.
 */
static int __init vsock_sim_init(void)
{
	int ret;
	int i;

	/* Nothing below works without inet_pf and the ops tables. */
	ret = kallsyms_hack_init();
	if (ret) {
		vs_err("kallsyms_hack_init failed, ret:%d", ret);
		return -ENOENT;
	}

	parse_cid_ip_maps();
	vs_err("get cid_ip map:");
	for (i = 0; i < cid_index; i++) {
		vs_err("\t %u : %s", cids[i], ips[i]);
	}

	sock_unregister(AF_VSOCK);
	ret = sock_register(&vsock_sim_family_ops);
	if (ret) {
		vs_err("could not register vsock_sim_family_ops: %d", ret);
		/* Do not leave the system without the family we just removed. */
		if (vsock_pf && sock_register(vsock_pf))
			vs_err("could not restore original AF_VSOCK family");
		return ret;
	}

	return 0;
}

/*
 * Module exit point: remove the simulated AF_VSOCK family and re-register
 * the family entry that was saved in vsock_pf at load time, if there was one.
 */
static void __exit vsock_sim_exit(void)
{
	sock_unregister(AF_VSOCK);

	/* vsock_pf is NULL when no AF_VSOCK family existed at load time. */
	if (vsock_pf) {
		int ret = sock_register(vsock_pf);

		if (ret)
			vs_err("could not restore original AF_VSOCK family: %d", ret);
	}

	kallsyms_hack_fini();
}

/*
 * cid_ip_map=<cid>-<ipv4>[,<cid>-<ipv4>...]
 * Copied into the cid_ip_map buffer. Within this file it is parsed only
 * once, from vsock_sim_init().
 */
module_param_string(cid_ip_map, cid_ip_map, sizeof(cid_ip_map), 0600);
/* Load and unload hooks. */
module_init(vsock_sim_init);
module_exit(vsock_sim_exit);
MODULE_AUTHOR("dengguangxing@huawei.com");
MODULE_LICENSE("GPL");
